import torch
from torch import nn
from torch.nn import functional as F

# Demo: push a random batch of images through a 2-D batch-norm layer and
# confirm that the layer preserves the input shape.

# A batch of 32 three-channel 480x640 inputs drawn from N(0, 1).
# With the default float32 dtype this is ~118 MB, and the output is the same size.
X = torch.randn((32, 3, 480, 640))

# One normalization channel per input channel (the size of dim 1 of X).
bn = nn.BatchNorm2d(3)

# A freshly constructed module is in training mode, so this call normalizes
# with the batch's own per-channel statistics and updates the running stats.
y = bn(X)

# Batch norm is shape-preserving, so this prints torch.Size([32, 3, 480, 640]).
print(y.size())

